{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0db242ba",
   "metadata": {},
   "source": [
    "# Knowledge graph completion with PyKeen and Neo4j\n",
    "## Integrate PyKeen library with Neo4j for multi-class link prediction using knowledge graph embedding models\n",
    "\n",
    "A couple of weeks ago, I met Francois Vanderseypen, a Graph Data Science consultant. We decided to join forces and start a Graph Machine learning blog series. This blog post will present how to perform knowledge graph completion, which is simply a multi-class link prediction. Instead of just predicting is a link, we are also trying to predict its type.\n",
    "\n",
    "For knowledge graph completion, the underlying graph should contain multiple types of relationships. Otherwise, if you are dealing with only a single kind of relationship, you can use the standard link prediction techniques that do not consider the relationship type. The example visualization has only a single node type, but in practice, your input graph can consists of multiple node types as well.\n",
    "\n",
    "We have to use the knowledge graph embedding models for a multi-class link prediction pipeline instead of plain node embedding models.\n",
    "What's the difference, you may ask.\n",
    "While node embedding models embed only nodes, the knowledge graph embedding models embed both nodes and relationships.\n",
    "\n",
    "The standard syntax to describe the pattern is that the starting node is called head (h), the end or target node is referred to as tail (t), and the relationship is r.\n",
    "The intuition behind the knowledge graph embedding model such as TransE is that the embedding of the head plus the relationship is close to the embedding of the tail if the relationship is present.\n",
    "\n",
    "The predictions are then quite simple. For example, if you want to predict new relationships for a specific node, you just sum the node plus the relationship embedding and evaluate if any of the nodes are near the embedding sum.\n",
    "\n",
    "# Prepare the data in Neo4j Desktop\n",
    "\n",
    "To follow along with this tutorial, I recommend you download the Neo4j Desktop application.\n",
    "\n",
    "Once you have installed the Neo4j Desktop, you can download the database dump and use it to restore a database instance. https://drive.google.com/file/d/1u34cFBYvBtdBsqOUPdmbcIyIt88IiZYe/view?usp=sharing\n",
    "\n",
    "Our subset of the Hetionet graph contains genes, compounds, and diseases. There are many relationships between them, and you would probably need to be in the biomedical domain to understand them, so I won't go into details.\n",
    "In our case, the most important relationship is the treats relationship between compounds and diseases. This blog post will use the knowledge graph embedding models to predict new treats relationships. You could think of this scenario as a drug repurposing task.\n",
    "# PyKeen\n",
    "PyKeen is an incredible, simple-to-use library that can be used for knowledge graph completion tasks.\n",
    "Currently, it features 35 knowledge graph embedding models and even supports out-of-the-box hyper-parameter optimizations.\n",
    "I like it due to its high-level interface, making it very easy to construct a PyKeen graph and train an embedding model.\n",
    "\n",
    "# Transform a Neo4j to a PyKeen graph\n",
    "Now we will move on to the practical part of this post.\n",
    "First, we will transform the Neo4j graph to the PyKeen graph and split the train-test data. To begin, we have to define the connection to the Neo4j database."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ee9a1139",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Neo4j connections\n",
    "from neo4j import GraphDatabase\n",
    "import pandas as pd\n",
    "\n",
    "host = 'bolt://localhost:7687'\n",
    "user = 'neo4j'\n",
    "password = 'letmein'\n",
    "driver = GraphDatabase.driver(host,auth=(user, password))\n",
    "                                         \n",
    "\n",
    "def run_query(query, params={}):\n",
    "    with driver.session() as session:\n",
    "        result = session.run(query, params)\n",
    "        return pd.DataFrame([r.values() for r in result], columns=result.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "079ff947",
   "metadata": {},
   "source": [
    "The `run_query` function executes a Cypher query and returns the output in the form of a Pandas dataframe. The PyKeen library has a `from_labeled_triples` that takes a list of triples as an input and constructs a graph from it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "562a480b",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = run_query(\"\"\"\n",
    "MATCH (s)-[r]->(t)\n",
    "RETURN toString(id(s)) as source, toString(id(t)) AS target, type(r) as type\n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a5dcda05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>target</th>\n",
       "      <th>type</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>12590</td>\n",
       "      <td>interacts</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>8752</td>\n",
       "      <td>interacts</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>7915</td>\n",
       "      <td>interacts</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>21711</td>\n",
       "      <td>interacts</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>6447</td>\n",
       "      <td>interacts</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  source target       type\n",
       "0      0  12590  interacts\n",
       "1      0   8752  interacts\n",
       "2      0   7915  interacts\n",
       "3      0  21711  interacts\n",
       "4      0   6447  interacts"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "045f9596",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pykeen.triples import TriplesFactory\n",
    "\n",
    "\n",
    "tf = TriplesFactory.from_labeled_triples(\n",
    "  data[[\"source\", \"type\", \"target\"]].values,\n",
    "  create_inverse_triples=False,\n",
    "  entity_to_id=None,\n",
    "  relation_to_id=None,\n",
    "  compact_id=False,\n",
    "  filter_out_candidate_inverse_relations=True,\n",
    "  metadata=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f424bd14",
   "metadata": {},
   "source": [
    "This example has a generic Cypher query that can be used to fetch any Neo4j dataset and construct a PyKeen from it. Notice that we use the internal Neo4j ids of nodes to build the triples data frame. For some reason, the PyKeen library expects the triple elements to be all strings, so we simply cast the internal ids to string.\n",
    "Now that we have our PyKeen graph, we can use the split method to perform the train-test data split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7d77b130",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "using automatically assigned random_state=3324760580\n"
     ]
    }
   ],
   "source": [
    "training, testing, validation = tf.split([.8, .1, .1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fcdfcf6",
   "metadata": {},
   "source": [
    "It couldn't get any easier than this. I must congratulate the PyKeen authors for developing such a straightforward interface.\n",
    "# Train a knowledge graph embedding model\n",
    "Now that we have the train-test data available, we can go ahead and train a knowledge graph embedding model. We will use the RotatE model in this example. I am not that familiar with all the variations of the embedding models, but if you want to learn more, I would suggest the lecture by Jure Leskovec I linked above.\n",
    "We won't perform any hyper-parameter optimization to keep the tutorial simple. I've chosen to use 20 epochs and defined the dimension size to be 512."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e47868ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:pykeen.utils:No cuda devices were available. The model runs on CPU\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "721c39869f154a6399e00c30d54cff41",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training epochs on cpu:   0%|          | 0/20 [00:00<?, ?epoch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pykeen.evaluation.evaluator:Currently automatic memory optimization only supports GPUs, but you're using a CPU. Therefore, the batch_size will be set to the default value.\n",
      "INFO:pykeen.evaluation.evaluator:No evaluation batch_size provided. Setting batch_size to '32'.\n",
      "INFO:pykeen.evaluation.evaluator:Evaluation took 2050.90s seconds\n",
      "INFO:pykeen.training.training_loop:=> Saved checkpoint after having finished epoch 10.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training batches on cpu:   0%|          | 0/1756 [00:00<?, ?batch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pykeen.evaluation.evaluator:Currently automatic memory optimization only supports GPUs, but you're using a CPU. Therefore, the batch_size will be set to the default value.\n",
      "INFO:pykeen.evaluation.evaluator:No evaluation batch_size provided. Setting batch_size to '32'.\n",
      "INFO:pykeen.evaluation.evaluator:Evaluation took 2317.22s seconds\n",
      "INFO:pykeen.training.training_loop:=> Saved checkpoint after having finished epoch 20.\n",
      "INFO:pykeen.training.training_loop:=> loading checkpoint '/tmp/tmpx04jauw1'\n",
      "INFO:pykeen.training.training_loop:=> loaded checkpoint '/tmp/tmpx04jauw1' stopped after having finished epoch 20\n",
      "INFO:pykeen.evaluation.evaluator:Currently automatic memory optimization only supports GPUs, but you're using a CPU. Therefore, the batch_size will be set to the default value.\n",
      "INFO:pykeen.evaluation.evaluator:No evaluation batch_size provided. Setting batch_size to '32'.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f20cb75303534b58bd53154208c2987a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating on cpu:   0%|          | 0.00/56.2k [00:00<?, ?triple/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pykeen.evaluation.evaluator:Evaluation took 2556.08s seconds\n"
     ]
    }
   ],
   "source": [
    "from pykeen.pipeline import pipeline\n",
    "\n",
    "result = pipeline(\n",
    "    training=training,\n",
    "    testing=testing,\n",
    "    validation=validation,\n",
    "    model='RotatE',\n",
    "    stopper='early',\n",
    "    epochs=20,\n",
    "    dimensions=512,\n",
    "    random_seed=420\n",
    "\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b01aecd6",
   "metadata": {},
   "source": [
    "# Multi-class link prediction\n",
    "The PyKeen library supports multiple methods for multi-class link prediction.\n",
    "You could find the top K predictions in the network, or you can be more specific and define a particular head node and relationship type and evaluate if there are any new connections predicted.\n",
    "\n",
    "In this example, you will predict new treats relationships for the L-Asparagine compound. Because we used the internal node ids for mapping, we first have to retrieve the node id of L-Asparagine from Neo4j and input it into the prediction method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "3f870831",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "       tail_id tail_label     score  in_training\n",
      "13279    13279        322 -7.856274        False\n",
      "5671      5671      15821 -7.987956        False\n",
      "3561      3561      13667 -8.157429        False\n",
      "11437    11437      21700 -8.158304        False\n",
      "17674    17674       7714 -8.204805        False\n"
     ]
    }
   ],
   "source": [
    "from pykeen.models.predict import get_tail_prediction_df\n",
    "\n",
    "compound_id = run_query(\"\"\"\n",
    "MATCH (s:Compound)\n",
    "WHERE s.name = \"L-Asparagine\"\n",
    "RETURN toString(id(s)) as id\n",
    "\"\"\")['id'][0]\n",
    "\n",
    "\n",
    "df = get_tail_prediction_df(result.model, compound_id, 'treats', triples_factory=result.training)\n",
    "print(df.head(5))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b94d0907",
   "metadata": {},
   "source": [
    "# Store predictions to Neo4j\n",
    "For easier evaluation of the results, we will store the top five predictions back to Neo4j."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "cb6e1fc5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Empty DataFrame\n",
       "Columns: []\n",
       "Index: []"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "candidate_nodes = df[df['in_training'] == False].head(5)['tail_label'].to_list()\n",
    "\n",
    "run_query(\"\"\"\n",
    "MATCH (n)\n",
    "WHERE id(n) = toInteger($compound_id)\n",
    "UNWIND $candidates as ca\n",
    "MATCH (c)\n",
    "WHERE id(c) = toInteger(ca)\n",
    "MERGE (n)-[:PREDICTED_TREATS]->(c)\n",
    "\"\"\", {'compound_id':compound_id, 'candidates': candidate_nodes})\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53963721",
   "metadata": {},
   "source": [
    "# Inspect results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "1bee3736",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>compound</th>\n",
       "      <th>disease</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>L-Asparagine</td>\n",
       "      <td>Crohn's disease</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>L-Asparagine</td>\n",
       "      <td>hematologic cancer</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>L-Asparagine</td>\n",
       "      <td>colon cancer</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>L-Asparagine</td>\n",
       "      <td>stomach cancer</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>L-Asparagine</td>\n",
       "      <td>chronic obstructive pulmonary disease</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       compound                                disease\n",
       "0  L-Asparagine                        Crohn's disease\n",
       "1  L-Asparagine                     hematologic cancer\n",
       "2  L-Asparagine                           colon cancer\n",
       "3  L-Asparagine                         stomach cancer\n",
       "4  L-Asparagine  chronic obstructive pulmonary disease"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "run_query(\"\"\"\n",
    "MATCH (c:Compound)-[:PREDICTED_TREATS]->(d:Disease)\n",
    "RETURN c.name as compound, d.name as disease\n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2483793b",
   "metadata": {},
   "source": [
    "# Explaining predictions\n",
    "As far as I know, the knowledge graph embedding model is not that useful for explaining predictions. On the other hand, you could use the existing connections in the graph to present the information to a medical doctor and let him decide if the predictions make sense or not.\n",
    "For example, you could investigate direct and indirect paths between L-Asparagine and colon cancer with the following Cypher query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6455bab4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>[n in nodes(p) | n.name]</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[L-Asparagine, ASRGL1, SSBP2, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[L-Asparagine, SLC38A3, PLXNA1, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[L-Asparagine, ASRGL1, NME1, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[L-Asparagine, SLC1A5, VEGFA, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[L-Asparagine, ASRGL1, GDF15, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>[L-Asparagine, SLC1A5, FZD5, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>[L-Asparagine, ASNS, CCNB1, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>[L-Asparagine, ASNS, HSF1, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>[L-Asparagine, SLC38A3, VEGFA, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>[L-Asparagine, SLC1A5, OXCT1, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>[L-Asparagine, ASRGL1, AKR7A2, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>[L-Asparagine, ASRGL1, CDK7, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>[L-Asparagine, SLC38A3, RBM28, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>[L-Asparagine, ASNS, CEBPB, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>[L-Asparagine, SLC1A5, GNAI1, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>[L-Asparagine, ASNS, G3BP1, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>[L-Asparagine, SLC1A5, APC, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>[L-Asparagine, SLC1A5, RUVBL1, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>[L-Asparagine, SLC1A5, SRC, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>[L-Asparagine, ASRGL1, PSMG1, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>[L-Asparagine, SLC1A5, G3BP1, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>[L-Asparagine, ASRGL1, HSD17B10, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>[L-Asparagine, NARS, TP53, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>[L-Asparagine, SLC38A3, EZH2, colon cancer]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>[L-Asparagine, ASNS, BRAF, colon cancer]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                          [n in nodes(p) | n.name]\n",
       "0      [L-Asparagine, ASRGL1, SSBP2, colon cancer]\n",
       "1    [L-Asparagine, SLC38A3, PLXNA1, colon cancer]\n",
       "2       [L-Asparagine, ASRGL1, NME1, colon cancer]\n",
       "3      [L-Asparagine, SLC1A5, VEGFA, colon cancer]\n",
       "4      [L-Asparagine, ASRGL1, GDF15, colon cancer]\n",
       "5       [L-Asparagine, SLC1A5, FZD5, colon cancer]\n",
       "6        [L-Asparagine, ASNS, CCNB1, colon cancer]\n",
       "7         [L-Asparagine, ASNS, HSF1, colon cancer]\n",
       "8     [L-Asparagine, SLC38A3, VEGFA, colon cancer]\n",
       "9      [L-Asparagine, SLC1A5, OXCT1, colon cancer]\n",
       "10    [L-Asparagine, ASRGL1, AKR7A2, colon cancer]\n",
       "11      [L-Asparagine, ASRGL1, CDK7, colon cancer]\n",
       "12    [L-Asparagine, SLC38A3, RBM28, colon cancer]\n",
       "13       [L-Asparagine, ASNS, CEBPB, colon cancer]\n",
       "14     [L-Asparagine, SLC1A5, GNAI1, colon cancer]\n",
       "15       [L-Asparagine, ASNS, G3BP1, colon cancer]\n",
       "16       [L-Asparagine, SLC1A5, APC, colon cancer]\n",
       "17    [L-Asparagine, SLC1A5, RUVBL1, colon cancer]\n",
       "18       [L-Asparagine, SLC1A5, SRC, colon cancer]\n",
       "19     [L-Asparagine, ASRGL1, PSMG1, colon cancer]\n",
       "20     [L-Asparagine, SLC1A5, G3BP1, colon cancer]\n",
       "21  [L-Asparagine, ASRGL1, HSD17B10, colon cancer]\n",
       "22        [L-Asparagine, NARS, TP53, colon cancer]\n",
       "23     [L-Asparagine, SLC38A3, EZH2, colon cancer]\n",
       "24        [L-Asparagine, ASNS, BRAF, colon cancer]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "run_query(\"\"\"\n",
    "MATCH (c:Compound {name: \"L-Asparagine\"}),(d:Disease {name:\"colon cancer\"})\n",
    "WITH c,d\n",
    "MATCH p=AllShortestPaths((c)-[r:binds|regulates|interacts|upregulates|downregulates|associates*1..4]-(d))\n",
    "RETURN [n in nodes(p) | n.name] LIMIT 25\n",
    "\"\"\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
